#include <math.h>
#include <mpi.h>
#include "cell.h"
#include "comm.h"
/*
 * Draw one sample from the standard normal distribution using the
 * Box-Muller transform, with rand() as the uniform source.
 *
 * Two rand() calls are consumed per sample (first for the radius, then for
 * the angle); only the sine branch of the transform is returned.
 */
INLINE real normal(void) {
  /*
   * u feeds log(), so it must never be 0.  rand() may return RAND_MAX, in
   * which case 1 - rand() / RAND_MAX is 0 and log(0) yields -inf.
   * (rand() + 1) / (RAND_MAX + 1) lies in (0, 1] for every possible rand().
   */
  real u = (rand() + 1.0) / (RAND_MAX + 1.0);
  /* v only selects the angle, so both endpoints of [0, 1] are harmless. */
  real v = 1 - rand() * (1.0 / RAND_MAX);
  real r = sqrt(-2 * log(u));
  real theta = 2 * M_PI * v;
  return r * sin(theta);
}
/*
 * Give every atom in the local cells a random initial velocity, then
 * rescale all velocities so that the temperature computed from the
 * reduced kinetic energy matches `temp`.
 *
 * grid: cell grid; the velocity v[i] of each atom in every cell visited by
 *       FOREACH_LOCAL_CELL is overwritten.
 * temp: target temperature.
 * mpp:  communication handle forwarded to comm_vallreduce.
 *
 * When the reduced atom count or kinetic energy is not positive (no atoms,
 * or temp == 0 so every velocity drawn is zero) the rescale step is
 * skipped; otherwise it would evaluate sqrt(temp / 0) and overwrite every
 * velocity with NaN.
 *
 * NOTE(review): M_K and M_NA are defined outside this file - presumably
 * Boltzmann's constant and Avogadro's number; the 1000, 1e-5 and 1e10
 * factors look like unit conversions.  Confirm against cell.h.
 */
void create_velocity(cellgrid_t *grid, real temp, mpp_t *mpp){
  real factor = sqrt(M_K * temp * M_NA * 1000);
  real kin = 0;
  long natoms = 0;
  FOREACH_LOCAL_CELL(grid, x, y, z, cell) {
    for (int i = 0; i < cell->natom; i ++){
      /* Per-atom velocity scale, shared by all three components. */
      real sigma = factor / sqrt(cell->mass[i]) * 1e-5;
      real vx = normal() * sigma;
      real vy = normal() * sigma;
      real vz = normal() * sigma;
      /* Accumulates m * |v|^2; the factor 1/2 is applied after the loop. */
      kin += cell->mass[i] * (vx * vx + vy * vy + vz * vz);
      cell->v[i].x = vx;
      cell->v[i].y = vy;
      cell->v[i].z = vz;
      natoms ++;
    }
  }
  real kin_jole = kin * 0.5 / (M_NA * 1000) * 1e10;
  real natomsf = natoms;

  /*
   * NOTE(review): presumably sums both values over all ranks and stores the
   * totals back through the pointers - confirm against comm.h.
   */
  comm_vallreduce(2, mpi_real, MPI_SUM, mpp, &natomsf, &kin_jole);

  /*
   * Nothing to rescale without atoms or kinetic energy.  The negated
   * comparisons also reject NaN.
   */
  if (!(natomsf > 0) || !(kin_jole > 0)) {
    return;
  }

  /* Temperature of the generated velocities: 2 * E_kin / (3 * N * M_K). */
  real t0 = kin_jole * 2 / (M_K * natomsf * 3);
  real tscale = sqrt(temp / t0);
  FOREACH_LOCAL_CELL(grid, x, y, z, cell) {
    for (int i = 0; i < cell->natom; i ++) {
      cell->v[i].x *= tscale;
      cell->v[i].y *= tscale;
      cell->v[i].z *= tscale;
    }
  }
}